Top-K KL Divergence loss #747
Conversation
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files:
@@ Coverage Diff @@
## main #747 +/- ##
=======================================
Coverage 74.62% 74.62%
=======================================
Files 192 192
Lines 18989 18989
=======================================
Hits 14171 14171
Misses 4818 4818

☔ View full report in Codecov by Sentry.
📝 Walkthrough

Added Top-K logit KL-divergence support to Megatron-based distillation by introducing a TopKLogitsKLLoss class and a logit_kl_topk configuration parameter. The implementation includes distributed, tensor-model-parallel-aware top-K gathering and KL computation, with an updated LogitsKLLoss for multi-TP handling.
Sequence Diagram(s)

sequenceDiagram
actor Student as Student Model
participant LSK as Top-K KL Loss
participant TP as TP Ranks
actor Teacher as Teacher Model
Student->>LSK: predictions (logits)
Teacher->>LSK: targets (logits)
LSK->>LSK: Extract local top-K logits per rank
LSK->>TP: Gather top-K indices from all TP ranks
TP->>LSK: Return global top-K across TP
LSK->>LSK: Collect top-K logits from all ranks
LSK->>LSK: Compute log probabilities on global top-K
LSK->>LSK: Compute KL divergence with temperature scaling
LSK-->>Student: KL loss value
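To ground the diagram, here is a minimal single-process sketch of a top-K KL divergence between teacher and student logits. It is only an illustration of the idea, not the PR's implementation (which additionally gathers top-K indices across tensor-parallel ranks); the function name and default arguments are made up for this example.

```python
import torch
import torch.nn.functional as F


def topk_kl_loss(student_logits, teacher_logits, top_k=5, temperature=1.0):
    """KL(teacher || student) restricted to the teacher's top-K vocab entries."""
    # Temperature-scale both distributions before picking the support.
    teacher_scaled = teacher_logits / temperature
    student_scaled = student_logits / temperature

    # The teacher selects the K most likely tokens per position.
    _, topk_idx = teacher_scaled.topk(top_k, dim=-1)

    # Restrict both sets of logits to that support and renormalize over it.
    teacher_logprobs = F.log_softmax(teacher_scaled.gather(-1, topk_idx), dim=-1)
    student_logprobs = F.log_softmax(student_scaled.gather(-1, topk_idx), dim=-1)

    # KL(teacher || student), averaged over batch and sequence positions.
    kl = (teacher_logprobs.exp() * (teacher_logprobs - student_logprobs)).sum(dim=-1)
    return kl.mean()


# Dummy [batch, seq, vocab] logits.
student = torch.randn(2, 4, 128, requires_grad=True)
teacher = torch.randn(2, 4, 128)
loss = topk_kl_loss(student, teacher, top_k=5, temperature=2.0)
loss.backward()
```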
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~75 minutes
Pre-merge checks: ✅ Passed checks (3 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
modelopt/torch/distill/plugins/megatron.py (1)
73-79: Consider validating logit_kl_topk > 0 when set.

The __post_init__ validates other parameters but doesn't check that logit_kl_topk is positive when not None. A value of 0 or negative would cause issues downstream in TopKLogitsKLLoss.

Suggested validation:

  def __post_init__(self):
      assert len(self.logit_layers) == 2, f"{self.logit_layers=}"
      assert all(len(pair) in (2, 3) for pair in self.intermediate_layer_pairs), (
          f"{self.intermediate_layer_pairs=}"
      )
      assert self.kd_loss_scale > 0, f"{self.kd_loss_scale=}"
      assert self.logit_kl_temperature > 0, f"{self.logit_kl_temperature=}"
+     assert self.logit_kl_topk is None or self.logit_kl_topk > 0, f"{self.logit_kl_topk=}"

tests/gpu/torch/distill/plugins/test_distill_megatron.py (1)
38-127: Solid integration test for LogitsKLLoss.

The test covers the essential flow: model creation, distillation setup, forward pass, loss computation, and backward pass.
Consider adding gradient verification for more robust testing:
  # After backward pass, verify gradients exist
  for name, param in distillation_model.named_parameters():
      if param.requires_grad and param.grad is not None:
          assert param.grad.abs().sum() > 0, f"Zero gradients for {name}"
          break  # At least one param has non-zero gradient
📜 Review details
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- modelopt/torch/distill/plugins/megatron.py
- tests/gpu/torch/distill/plugins/test_distill_megatron.py
🧰 Additional context used
🧬 Code graph analysis (2)
tests/gpu/torch/distill/plugins/test_distill_megatron.py (6)
tests/_test_utils/import_helper.py (1)
  skip_if_no_megatron (46-77)
tests/_test_utils/torch/distributed/utils.py (1)
  spawn_multiprocess_job (51-65)
tests/_test_utils/torch/megatron/models.py (1)
  get_mcore_gpt_model (125-244)
tests/_test_utils/torch/megatron/utils.py (1)
  run_mcore_inference_with_dummy_input (122-129)
modelopt/torch/distill/plugins/megatron.py (2)
  DistillationConfig (52-95)
  adjust_distillation_model_for_mcore (558-616)
modelopt/torch/distill/distillation_model.py (2)
  loss_balancer (134-136)
  compute_kd_loss (237-288)
modelopt/torch/distill/plugins/megatron.py (1)
modelopt/torch/distill/distillation_model.py (1)
  forward (209-235)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
🔇 Additional comments (7)
modelopt/torch/distill/plugins/megatron.py (4)
129-137: LGTM! The conditional instantiation logic is clear and correctly passes the appropriate parameters to each loss class.
346-370: LGTM! The use of dist_nn.functional.all_reduce for computing global softmax denominators correctly preserves gradients through the distributed operation. The comment on lines 347-348 clearly explains the rationale.
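For context on the pattern this comment refers to, the sketch below shows one way a global log-softmax can be computed when the vocab dimension is sharded across TP ranks, using the autograd-aware all-reduce for the denominator. It is an illustration under stated assumptions (an initialized process group, equal-sized vocab shards), not the code in megatron.py.

```python
import torch
import torch.distributed as dist
from torch.distributed import nn as dist_nn


def sharded_log_softmax(local_logits: torch.Tensor, tp_group=None) -> torch.Tensor:
    """Log-softmax over a vocab dimension sharded across tensor-parallel ranks."""
    # Global max for numerical stability; treated as a constant, so a plain
    # (non-differentiable) all_reduce is fine here.
    with torch.no_grad():
        global_max = local_logits.amax(dim=-1, keepdim=True)
        dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=tp_group)

    shifted = local_logits - global_max

    # Local partial denominator, then an autograd-aware sum across ranks so the
    # backward pass distributes gradients to every shard's contribution.
    local_denom = shifted.exp().sum(dim=-1, keepdim=True)
    global_denom = dist_nn.functional.all_reduce(local_denom, group=tp_group)

    # Log-probabilities for this rank's vocab shard.
    return shifted - global_denom.log()
```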
373-458: Well-structured Top-K implementation with proper TP handling.

The implementation correctly:
- Preserves gradients through dist_nn.functional.all_gather
- Guards against top_k exceeding total vocabulary size
- Handles edge cases where local vocabulary size is smaller than top_k
- Avoids unnecessary TP reduction since all ranks compute the same global top-K
The docstring appropriately warns users about memory/communication implications for large K values.
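As a companion sketch to the points above (again illustrative rather than the PR's code), this is roughly what gathering a global top-K over a TP-sharded vocab dimension looks like. `vocab_offset` (the rank's starting global vocab index) is an assumed input, and equal shard sizes are assumed so the all-gather shapes match.

```python
import torch
from torch.distributed import nn as dist_nn


def global_topk(local_logits: torch.Tensor, top_k: int, vocab_offset: int, tp_group=None):
    """Global top-K values and vocab indices when the vocab dim is TP-sharded."""
    # Each rank can contribute at most its local shard size.
    local_k = min(top_k, local_logits.size(-1))
    local_vals, local_idx = local_logits.topk(local_k, dim=-1)
    local_idx = local_idx + vocab_offset  # shard-local ids -> global vocab ids

    # Autograd-aware gather: gradients flow back to each rank's logits through
    # the gathered values (the index gather is communication only).
    all_vals = torch.cat(dist_nn.functional.all_gather(local_vals, group=tp_group), dim=-1)
    all_idx = torch.cat(dist_nn.functional.all_gather(local_idx, group=tp_group), dim=-1)

    # Every rank now holds the same candidate set, so each can take the same
    # global top-K locally with no further TP reduction.
    k = min(top_k, all_vals.size(-1))
    top_vals, pos = all_vals.topk(k, dim=-1)
    top_idx = all_idx.gather(-1, pos)
    return top_vals, top_idx
```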
490-493: Reasonable change to accommodate TopKLogitsKLLoss.

The "Logits" in _key check is necessary since TopKLogitsKLLoss doesn't start with "Logits". Be aware this could match unintended keys if future loss classes contain "Logits" anywhere in the name.

tests/gpu/torch/distill/plugins/test_distill_megatron.py (3)
19-33: LGTM! The skip guard pattern correctly prevents test failures when Megatron or required dependencies are unavailable.
129-217: Test structure is appropriate.

The larger vocab_size=128 correctly provides enough vocabulary entries for meaningful top-k testing with top_k=5. The duplication with _test_logits_kl_loss could be reduced by extracting common setup into a helper, but the explicit test isolation is acceptable for test readability.
220-237: Consider handling edge cases for device count.

The tests assume torch.cuda.device_count() >= 2 for meaningful tensor parallelism testing. If only one GPU is available, size=1 would run without TP, which may not exercise the distributed code paths.

Consider adding a skip condition:

  import pytest

  def test_logits_kl_loss():
      """Test LogitsKLLoss with TP parallelism."""
      if torch.cuda.device_count() < 2:
          pytest.skip("Need at least 2 GPUs for TP testing")
      set_seed(SEED)
      spawn_multiprocess_job(
          size=torch.cuda.device_count(),
          job=_test_logits_kl_loss,
          backend="nccl",
      )

Alternatively, verify that the existing helper functions handle this appropriately in the test infrastructure.
  # We can't use standard all_reduce function here since the computation
  # that follows it isn't identical across TP ranks.
  denom_teacher = torch.sum(torch.exp(output_teacher), dim=-1, keepdim=True)
  denom_teacher = dist_nn.functional.all_reduce(denom_teacher, group=tp_group)
Is torch.distributed.nn.functional.all_reduce the same as torch.distributed.all_reduce?
No, it's the premium edition which allows gradient backprop through it
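In other words, the torch.distributed.nn variants are differentiable wrappers around the collectives. A small hedged check of the difference (it assumes a process group has already been initialized, e.g. under torchrun):

```python
import torch
import torch.distributed as dist
from torch.distributed import nn as dist_nn

# Assumes dist.init_process_group(...) has already been called on every rank.
x = torch.ones(4, requires_grad=True)

# torch.distributed.all_reduce mutates its input in place and records no
# autograd edge, so the communication is invisible to backward().
y = x.detach().clone()
dist.all_reduce(y)  # y now holds the sum across ranks, gradient-free

# torch.distributed.nn.functional.all_reduce returns a new tensor whose
# backward pass all-reduces the incoming gradient across ranks.
z = dist_nn.functional.all_reduce(x)
z.sum().backward()
print(x.grad)  # populated: gradients flowed back through the collective
```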
What does this PR do?
Type of change: New feature
Overview: Adds a new KL-divergence logits loss that uses only the top-K vocabulary values.
Usage
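(A hypothetical usage sketch based on the config field described in the review walkthrough; the exact constructor arguments and defaults may differ from the merged API:)

```python
from modelopt.torch.distill.plugins.megatron import DistillationConfig

# Hypothetical values; remaining DistillationConfig fields are assumed to keep
# their defaults. Setting logit_kl_topk selects the Top-K KL loss variant,
# while leaving it as None keeps the standard full-vocab LogitsKLLoss.
kd_config = DistillationConfig(
    logit_kl_temperature=2.0,  # existing temperature knob
    logit_kl_topk=5,           # new: restrict the KL to the top-5 vocab entries
)
```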
Testing
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
- Added Top-K logit filtering capability for knowledge distillation workflows, enabling selective focus on high-probability tokens.

Improvements
- Enhanced distributed tensor model-parallel operations with improved awareness for gradient computation and reduction.
- Simplified legacy distributed operation constructs.

Tests
- Introduced comprehensive test coverage for Megatron-based distillation, validating both standard and Top-K filtering variants.
Signed-off-by: Asha Anoosheh <[email protected]>
Signed-off-by: Jingyu Xin <[email protected]>